{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8d593eea-cf1d-4dd4-a7e3-efdea9c79796",
   "metadata": {},
   "source": [
    "# Train_01"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38bebaef-baa3-4950-a5b7-e7f7cc8f60ff",
   "metadata": {},
   "source": [
    "## sklearn 对照"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5374c57a-f0bb-4252-843a-901933dc063b",
   "metadata": {},
   "outputs": [],
   "source": [
    "#导入模块\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn import datasets\n",
    "#k近邻函数\n",
    "from sklearn.neighbors import KNeighborsClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e571cc6b-fb28-4639-97bb-9fe8a57d4c1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import precision_score\n",
    "from sklearn.metrics import recall_score\n",
    "from sklearn.metrics import f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ca296d8d-d632-485f-8978-31c92792cc65",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the iris dataset that ships with scikit-learn.\n",
    "iris = datasets.load_iris()\n",
    "# Feature matrix and class labels.\n",
    "iris_X = iris.data\n",
    "iris_y = iris.target\n",
    "# Hold out 30% of the samples for testing. The seed is fixed so that\n",
    "# Restart & Run All reproduces the same split, and therefore the same\n",
    "# metrics for both models evaluated below.\n",
    "X_train, X_test, y_train, y_test = train_test_split(\n",
    "    iris_X, iris_y, test_size=0.3, random_state=42\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "153d2e0e-aa10-4cfd-b0ce-6949326ba7e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 2 2 0 2 2 0 1 2 0 1 0 2 1 0 2 1 2 2 2 1 2 2 0 2 2 2 1 0 0 0 0 1 2 0 2 1\n",
      " 1 1 1 0 1 2 2 2]\n"
     ]
    }
   ],
   "source": [
    "# Build a k-nearest-neighbours classifier with default settings and train it\n",
    "# on the training split (fit() returns the fitted estimator itself).\n",
    "knn = KNeighborsClassifier().fit(X_train, y_train)\n",
    "# Predict a class label for every held-out test sample.\n",
    "y_pred = knn.predict(X_test)\n",
    "# NOTE(review): this cell prints nothing; the saved stdout output comes from\n",
    "# an earlier run that still contained a print statement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9af318f4-31db-474d-83dd-ef82730fd251",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8666666666666667\n",
      "0.9111111111111111\n",
      "0.8666666666666667\n",
      "0.8682352941176471\n"
     ]
    }
   ],
   "source": [
    "# Score the predictions currently held in y_pred against the true test labels.\n",
    "s_accuracy = accuracy_score(y_test, y_pred)\n",
    "# Multi-class scores: per-class values are combined with average='weighted'.\n",
    "s_precision, s_recall, s_f1 = (\n",
    "    metric(y_test, y_pred, average='weighted')\n",
    "    for metric in (precision_score, recall_score, f1_score)\n",
    ")\n",
    "# One value per line, in the order: accuracy, precision, recall, F1.\n",
    "print(s_accuracy, s_precision, s_recall, s_f1, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e633904c-1597-40c1-8de4-dd3a3fb1f264",
   "metadata": {},
   "source": [
    "## SPFlow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "5cf475ce-f3a0-43a5-b69b-ac82d6cbe6b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from spn.algorithms.LearningWrappers import learn_parametric, learn_classifier\n",
    "from spn.structure.leaves.parametric.Parametric import Categorical, Gaussian\n",
    "from spn.structure.StatisticalTypes import MetaType\n",
    "from spn.structure.Base import Context\n",
    "from spn.algorithms.MPE import mpe\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "63fb19b2-b17c-43ce-8da6-d14c58301d9a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(105,)\n",
      "(105, 4)\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: shape of the training labels, then of the training features.\n",
    "print(np.shape(y_train), np.shape(X_train), sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "dc37b2d7-4a0c-4c19-9a31-0fe103beb458",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1-D array holding one NaN per row of X_test; it is stacked next to X_test\n",
    "# below as the (unknown) label column of the test matrix.\n",
    "array_nan = np.full(np.shape(X_test)[0], np.nan)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "dc21609d-6b2e-4a77-a1b6-6656c3ce47b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training matrix: the feature columns followed by the class label as the last column.\n",
    "train_data = np.concatenate((X_train, y_train[:, np.newaxis]), axis=1)\n",
    "\n",
    "# Test matrix with the same column layout, but the label column holds only NaN.\n",
    "test_classification = np.concatenate((X_test, array_nan[:, np.newaxis]), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "77b0c8eb-2ecc-4365-b36d-b48f1750cfd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhangshen/software/anaconda3/envs/PY38/lib/python3.8/site-packages/spn/structure/Base.py:157: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.\n",
      "  self.domains = np.asanyarray(domain)\n"
     ]
    }
   ],
   "source": [
    "# Column of train_data that holds the class label. The labels were appended\n",
    "# after the feature columns, so it is the last column (index 4 for iris).\n",
    "label_idx = train_data.shape[1] - 1\n",
    "\n",
    "# One leaf type per column: Gaussian for each of the four continuous features,\n",
    "# Categorical for the label column.\n",
    "ds_context = Context(\n",
    "    parametric_types=[Gaussian, Gaussian, Gaussian, Gaussian, Categorical]\n",
    ").add_domains(train_data)\n",
    "\n",
    "# learn_classifier learns one sub-SPN per distinct value found in column\n",
    "# label_idx. The previously hard-coded index 2 pointed at a feature column\n",
    "# (petal length) instead of the class label.\n",
    "spn_classification = learn_classifier(train_data, ds_context, learn_parametric, label_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "030b78c6-93b6-40db-8abd-572be3670a5b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhangshen/software/anaconda3/envs/PY38/lib/python3.8/site-packages/spn/structure/leaves/parametric/Inference.py:88: RuntimeWarning: divide by zero encountered in log\n",
      "  probs[idx_in] = np.array(np.log(node.p))[cat_data[~out_domain_ids]]\n"
     ]
    }
   ],
   "source": [
    "# Run MPE (most probable explanation) inference with the learned SPN on the\n",
    "# test matrix, whose label column is NaN. Per the SPFlow README, the NaN\n",
    "# entries are replaced by their most probable values in the returned matrix.\n",
    "result = mpe(spn_classification, test_classification)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "ae1975a8-fc74-48e4-86b8-f12c48b437c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The last column of the MPE result is the label column, so it holds the\n",
    "# labels predicted by the SPN.\n",
    "# NOTE(review): this rebinds y_pred, which held the KNN predictions; any\n",
    "# metrics cell run after this one scores the SPN output.\n",
    "y_pred = result[..., -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "03f955ed-c134-4168-966a-6392dd722b94",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8888888888888888\n",
      "0.9027450980392158\n",
      "0.8888888888888888\n",
      "0.8905572238905572\n"
     ]
    }
   ],
   "source": [
    "# Score the predictions currently held in y_pred against the true test labels.\n",
    "# NOTE(review): the s_* names are the same ones used for the KNN scores above,\n",
    "# so running this cell replaces those values.\n",
    "s_accuracy = accuracy_score(y_test, y_pred)\n",
    "# Multi-class scores: per-class values are combined with average='weighted'.\n",
    "s_precision, s_recall, s_f1 = (\n",
    "    metric(y_test, y_pred, average='weighted')\n",
    "    for metric in (precision_score, recall_score, f1_score)\n",
    ")\n",
    "# One value per line, in the order: accuracy, precision, recall, F1.\n",
    "print(s_accuracy, s_precision, s_recall, s_f1, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46ff8b4a-e212-41d9-b4b6-efc7caf03fba",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
